import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

import dgl
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset

from gcn import GCN


def evaluate(model, features, labels, mask):
    """评价函数"""
    model.eval()
    with torch.no_grad():
        logits = model(features)  # 将特征输入模型查看结果
        logits = logits[mask]  # 获取某类数据的结果（train/val/test）
        labels = labels[mask]  # 获取某类数据的真实值
        _, indices = torch.max(logits, dim=1)  # 取结果中最大值为预测值
        correct = torch.sum(indices == labels)  # 获取准确的个数
        # 返回准确率
        return correct.item() * 1.0 / len(labels)


def main(config):
    # 加载和预处理数据集
    if config.dataset == 'cora':
        data = CoraGraphDataset(raw_dir=config.data_path)
    elif config.dataset == 'citeseer':
        data = CiteseerGraphDataset(raw_dir=config.data_path)
    elif config.dataset == 'pubmed':
        data = PubmedGraphDataset(raw_dir=config.data_path)
    else:
        raise ValueError('Unknown dataset: {}'.format(config.dataset))

    g = data[0]  # 获取图
    g = g.to(config.device)  # 变量放到device上

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("------数据集描述------\n \
    # 边数: {:d}\n \
    # 类别数: {:d}\n \
    # 训练集大小: {:d}\n \
    # 验证集大小: {:d}\n \
    # 测试集大小: {:d}\n".format(
    n_edges, n_classes,
    train_mask.int().sum().item(),
    val_mask.int().sum().item(),
    test_mask.int().sum().item()))
    # 是否使用节点自映射
    if config.self_loop:
        g = dgl.remove_self_loop(g)
        g = dgl.add_self_loop(g)
    # 使用节点自映射后重新计算边的数量
    n_edges = g.number_of_edges()

    # normalization归一化
    degs = g.in_degrees().float()  # 对节点的度矩阵进行归一化
    norm = torch.pow(degs, -0.5)  # 根据公式可知，此处代表D^-1/2
    norm[torch.isinf(norm)] = 0  # 归一化后一些值为无穷大，将无穷大的值设为0
    norm = norm.to(config.device)
    g.ndata['norm'] = norm.unsqueeze(1)

    # 创建GCN模型
    model = GCN(g,
                in_feats,
                config.n_hidden,
                n_classes,
                config.n_layers,
                F.relu,
                config.dropout)
    model = model.to(config.device)
    loss_fcn = torch.nn.CrossEntropyLoss()

    # 使用Adam优化器
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    # 训练过程
    dur = []  # 记录epoch时间
    train_loss = []  # 记录训练损失变化
    train_acc = []  # 记录训练准确率变化
    for epoch in range(config.n_epochs):
        model.train()
        t0 = time.time()
        # forward前向传播，使用交叉熵损失
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        # 梯度归零后反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 保存训练时间
        dur.append(time.time() - t0)
        # 计算训练过程中验证集的准确率
        acc = evaluate(model, features, labels, val_mask)
        # 保存训练损失和准确率
        train_loss.append(loss.item())
        train_acc.append(acc)
        # 打印相关信息
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f}"
              .format(epoch+1, np.mean(dur), loss.item(), acc))

    print()
    # 计算测试集准确率
    acc = evaluate(model, features, labels, test_mask)
    print("Test accuracy {:.2%}".format(acc))
    if config.show:
        plot_train(train_loss, train_acc)


def plot_train(loss, acc):
    # 画图显示loss和acc的变化
    plt.figure(0)
    plt.title("Loss")
    epoch = list(range(1, len(loss) + 1))
    plt.plot(epoch, loss, color='red', label="Loss")
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.figure(1)
    plt.title("Accuracy")
    plt.plot(epoch, acc, color='blue', label="Accuracy")
    plt.legend()
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.show()


class Config(object):
    """网络参数配置"""
    def __init__(self):
        # 数据集路径
        self.data_path = './data'
        # 数据集(三种可选)
        self.dataset = 'cora'
        # self.dataset = 'citeseer'
        # self.dataset = 'pubmed'
        # 是否使用节点自映射(节点自环)
        self.self_loop = True
        # 使用设备 (cuda或者cpu)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # 隐藏层大小
        self.n_hidden = 16
        # 网络层数()
        self.n_layers = 1
        # 丢弃率
        self.dropout = 0.5
        # 学习率
        self.lr = 1e-2
        # 权重衰减参数
        self.weight_decay = 5e-4
        # 迭代次数
        self.n_epochs = 200
        # 是否画图显示loss和acc变化
        self.show = True


if __name__ == '__main__':
    config = Config()
    main(config)
